# -*- coding: utf-8 -*-
"""
Code to generate the model behaviour figures in Figure S1 (Supplementary Information).
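Comment/uncomment each parameter set as desired; later assignments override earlier
ones, so the last uncommented set in the file is the one that gets plotted.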

Created on Thu May  2 10:56:56 2024
@author: Andrei Sontag
"""

import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from matplotlib.patches import Rectangle


# Baseline rate parameters. Every "Behaviour" set further down reassigns all
# nine rates, so these values only take effect when all of those sets are
# commented out.
# N is not referenced by the phase-plane code below.
N = 100

# Rates grouped by suffix (rl, rs, sr) with prefixes a, b, c.
arl, brl, crl = 0.1, 0, 0.1
ars, brs, crs = 0, 3, 5
asr, bsr, csr = 0.2, 1, 2

# =============================================================================
# # Behaviour 1
# arl = 0
# brl = 0
# crl = 0
# ars = .6
# brs = 3
# crs = 3
# asr = 0.2
# bsr = 5
# csr = 0
# =============================================================================

# =============================================================================
# # Behaviour 2
# arl = 1.5
# brl = 1
# crl = 0
# ars = 0
# brs = 3
# crs = 5
# asr = 0.2
# bsr = 5
# csr = 0
# =============================================================================

# =============================================================================
# # Behaviour 3
# arl = 0
# brl = 1
# crl = 0
# ars = 0
# brs = 3
# crs = 5
# asr = 0.2
# bsr = 5
# csr = 0
# =============================================================================

# =============================================================================
# # Behaviour 4 (straight manifold)
# arl = 0
# brl = 0
# crl = 0
# ars = 10
# brs = 1
# crs = .1
# asr = 0.01
# bsr = 31
# csr = 0
# =============================================================================

# =============================================================================
# # Behaviour 5
# arl = 0
# brl = 0
# crl = 0
# ars = .1
# brs = .01
# crs = 10
# asr = 1
# bsr = .31
# csr = 0
# =============================================================================

# =============================================================================
# # Behaviour 6
# arl = 0.6
# brl = 0
# crl = 0
# ars = 0
# brs = 3
# crs = 3
# asr = .2
# bsr = 5
# csr = 0 
# =============================================================================

# =============================================================================
# # Behaviour 7
# arl = 0
# brl = 0
# crl = 0
# ars = .6
# brs = 3
# crs = 5
# asr = .2
# bsr = 5
# csr = 2.01
# =============================================================================

# Behaviour 8
arl = 0
brl = 0
crl = 0.1
ars = .6
brs = 3
crs = 5
asr = .2
bsr = 1
csr = 0

# Combinations of the rates used to locate the fixed points off the z = 0 axis.
A = -arl-ars
B = -crs
C = -crl-brs+bsr
D = arl
E = asr
F = crl+csr

# s = (D-A)/(C-F), which equals c1/c2 below. Besides z = 0, the dz/dt = 0
# nullcline is therefore the horizontal line w = 1 - s. When C == F (c2 == 0)
# that line does not exist: use NaN instead of aborting with ZeroDivisionError.
# Matplotlib skips NaN points, so the rest of the figure is still drawn.
s = (D-A)/(C-F) if C != F else np.nan

# zst is |z| of the fixed points (+zst, 1-s) and (-zst, 1-s), where the line
# w = 1 - s meets the dw/dt = 0 nullcline. It is NaN when B == 0 (formula
# undefined) or when the radicand is negative (np.sqrt then returns NaN and
# emits a RuntimeWarning), i.e. when no such real fixed point exists.
if B != 0:
    zst = np.sqrt((1-s)**2 + 4*(D*(1-s) + E*s + F*s*(1-s))/B )
else:
    zst = np.nan

# Coefficients of the vector field plotted below:
#   dz/dt = -c1*z + c2*z*(1-w)
#   dw/dt = -c3*w + c4*(1-w) - c5*(w**2 - z**2) + c6*w*(1-w)
c1 = ars + 2*arl
c2 = bsr - brs - 2*crl - csr
c3 = ars
c4 = 2*asr
c5 = crs/2
c6 = bsr - brs + csr

# Fine, evenly spaced 2001 x 2001 grid over [-2, 2] x [-2, 2] for the streamlines.
z = np.linspace(-2, 2, 2001)
w = np.linspace(-2, 2, 2001)
Z, W = np.meshgrid(z, w)

# Boolean mask of the points with |w| > |z| and w < 1. The vector field is
# multiplied by it, so it is exactly zero outside that region.
inside = (W**2 > Z**2) & (W < 1)

# Vector field (dz/dt, dw/dt) on the fine grid, zeroed outside the mask.
Ez = (-c1*Z + c2*Z*(1-W))*inside
Ew = (-c3*W + c4*(1-W) - c5*(W**2-Z**2) + c6*W*(1-W))*inside

 
# Tall figure with a large default font size.
plt.figure(figsize=(9.5, 18))
matplotlib.rcParams['font.size'] = 40

# Line width shared by the streamlines and nullclines drawn later.
linewidth = 5

# Hatch the whole [-2, 2]^2 square, then paint the triangle with vertices
# (0, 0), (1, 1), (-1, 1) white so only that triangle is left un-hatched.
ax = plt.axes()
ax.add_patch(Rectangle((-2, -2), 4, 4, fill=False, hatch='///', color='#D3D3D3'))

points = [[0, 0], [1, 1], [-1, 1]]
triangle = np.array([*points, points[0]])  # repeat the first vertex to close the polygon
ax.fill(triangle[:, 0], triangle[:, 1], color='white', alpha=1)
# Streamlines of the masked field, seeded at hand-picked points; one
# streamplot call per seed group.
seed_groups = (
    [[0.3, 0.999], [-0.1, 0.1]],
    [[0.1, 0.1], [-0.9, 0.999], [-0.05, 0.999]],
)
for seeds in seed_groups:
    plt.streamplot(Z, W, Ez, Ew, start_points=seeds, color='#CC7722',
                   arrowstyle='-|>', linewidth=linewidth)

# Thin black reference lines: w = 1, w = z and w = -z (the edges of the
# white triangle), then the axis w = 0.
for edge in ([1]*len(z), z, -z):
    plt.plot(z, edge, color='black', linestyle='-', linewidth=2)
plt.plot(z, [0]*len(z), 'k-', linewidth=2)

# dw/dt = 0 nullcline: the quadratic a1*w**2 + b1*w - (c4 + c5*z**2) = 0
# in w, solved for both roots at each z.
b1 = c3+c4-c6
a1 = (c5+c6)

zplot = np.linspace(-2, 2, 101)
root = np.sqrt(b1**2 + 4*a1*(c4 + c5*zplot**2))
w_upper = (-b1 + root)/(2*a1)
w_lower = (-b1 - root)/(2*a1)
plt.plot(zplot, w_upper, color='royalblue', linewidth=linewidth)
# NOTE(review): labels are attached here and below, but plt.legend() is never
# called in this script, so they are not displayed.
plt.plot(zplot, w_lower, color='royalblue', label=r'$dw/dt = 0$', linewidth=linewidth)

# dz/dt = 0 nullclines: the vertical line z = 0 and the horizontal line w = 1 - s.
plt.plot([0]*len(w), w, color='red', label=r'$dz/dt = 0$', linewidth=linewidth)
plt.plot(z, [1-s]*len(z), color='red', linewidth=linewidth)

# Fixed points off the z = 0 axis, at (zst, 1-s) and (-zst, 1-s).
for z_fp in (zst, -zst):
    plt.plot(z_fp, 1-s, marker='*', markersize=30,
             markeredgecolor='black', markerfacecolor='black')

# Zoom in around the triangle and keep the ticks sparse.
plt.axis([-1.2, 1.2, -0.4, 1.3])
ax.set_yticks([0, 0.5, 1])
ax.set_xticks([-1, 0, 1])

# Coarse 21 x 21 grid over [-2, 2] x [-2, 2] for the direction-field arrows.
z = np.linspace(-2, 2, 21)
w = np.linspace(-2, 2, 21)
Z, W = np.meshgrid(z, w)

# Vector field (dz/dt, dw/dt) on the coarse grid. Unlike the streamline field,
# it is not masked, so arrows also appear in the hatched area.
Ez = (-c1*Z + c2*Z*(1-W))
Ew = -c3*W + c4*(1-W) - c5*(W**2-Z**2) + c6*W*(1-W)

plt.quiver(Z, W, Ez, Ew, scale=40, width=1/200)
plt.tight_layout()
# Display the figure. Without this call, running the file as a plain script
# (non-interactive backend) builds the figure and exits without showing it.
plt.show()